{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5f21e08c-259d-4099-b1ed-b782bf94be05",
   "metadata": {},
   "source": [
    "Replace `key` in the cell below with your OpenAI API key (do not wrap it in quotation marks)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7f45bebd-fe6e-41a8-a6e1-bf56d765463e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: OPENAI_API_KEY=key\n"
     ]
    }
   ],
   "source": [
    "# Sets the OPENAI_API_KEY environment variable for this kernel session.\n",
    "# %env echoes the value into the cell output, so clear the output before sharing or committing the notebook.\n",
    "%env OPENAI_API_KEY=key"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7784d6c8-e28a-494b-b97f-374fcc94213f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import yaml\n",
    "from openai import OpenAI\n",
    "# Module-level client used by generate_prompts_sliders below.\n",
    "# NOTE(review): presumably picks up the key from the OPENAI_API_KEY environment variable set above; confirm against the openai docs.\n",
    "client = OpenAI()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6de55bbf-b978-478c-a8f0-db412bc52e9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _parse_slider_response(content):\n",
    "    '''\n",
    "    Parse the `Label: value` lines of a GPT reply into prompt fields.\n",
    "\n",
    "    Inputs\n",
    "    ------\n",
    "    content (str or None): raw text of the GPT reply\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    fields (dict): 'target', 'positive', 'unconditional' and 'neutral' prompts (empty string when missing)\n",
    "    preservation_groups (list of list of str): one list of attributes per ';'-separated preservation class\n",
    "    name (str or None): the slider name given in the reply, or None when there is no usable `Name:` line\n",
    "    '''\n",
    "    fields = {'target': '', 'positive': '', 'unconditional': '', 'neutral': ''}\n",
    "    # The reply labels the negative prompt `Negative`; the yaml entry stores it under 'unconditional'\n",
    "    label_to_field = {'target': 'target', 'positive': 'positive', 'negative': 'unconditional'}\n",
    "    preservation_groups = []\n",
    "    name = None\n",
    "    for line in (content or '').splitlines():\n",
    "        # Split on the first colon only so that a value containing ':' is kept whole\n",
    "        label, separator, value = line.partition(':')\n",
    "        if not separator:\n",
    "            continue\n",
    "        label = label.strip().lower()\n",
    "        value = value.strip()\n",
    "        if label in label_to_field:\n",
    "            fields[label_to_field[label]] = value\n",
    "        elif label == 'preservation':\n",
    "            # Classes are separated by ';' and the attributes inside a class by ','\n",
    "            preservation_groups = []\n",
    "            for group in value.split(';'):\n",
    "                attributes = [attribute.strip() for attribute in group.split(',') if attribute.strip()]\n",
    "                if attributes:\n",
    "                    preservation_groups.append(attributes)\n",
    "        elif label == 'name' and value:\n",
    "            name = value\n",
    "    # The neutral prompt is the bare target concept\n",
    "    fields['neutral'] = fields['target']\n",
    "    return fields, preservation_groups, name\n",
    "\n",
    "\n",
    "def _expand_preservation(fields, preservation_groups):\n",
    "    '''\n",
    "    Build one prompt set per combination of preserved attributes.\n",
    "\n",
    "    Every attribute of every preservation class is prefixed to all prompts, so two classes\n",
    "    with 5 and 2 attributes yield 10 prompt sets. With no classes the single input set is returned.\n",
    "    '''\n",
    "    prompt_sets = [dict(fields)]\n",
    "    for group in preservation_groups:\n",
    "        prompt_sets = [\n",
    "            {field: f'{attribute} {text.strip()}'.strip() for field, text in prompt_set.items()}\n",
    "            for attribute in group\n",
    "            for prompt_set in prompt_sets\n",
    "        ]\n",
    "    return prompt_sets\n",
    "\n",
    "\n",
    "def generate_prompts_sliders(slider_query,\n",
    "                             file_name_to_save=None,\n",
    "                             temperature=0.2,\n",
    "                             max_tokens=256,\n",
    "                             frequency_penalty=0.0,\n",
    "                             model=\"gpt-4-turbo-preview\",\n",
    "                             verbose=False,\n",
    "                             save=True):\n",
    "    '''\n",
    "    A function to automatically build prompts for text sliders using GPT4 (or any other openAI model).\n",
    "\n",
    "    Inputs\n",
    "    ------\n",
    "    slider_query (str): A natural language query describing the slider effects the user desired (eg: \"I want to make people older\")\n",
    "    file_name_to_save (str) (optional): a full name of the yaml file a user desires. If left as None, a name will be chosen by GPT\n",
    "    temperature (float) (optional): GPT temperature parameter (use smaller values for less randomness)\n",
    "    max_tokens (int) (optional): GPT output token limit\n",
    "    frequency_penalty (float) (optional): GPT frequency penalty\n",
    "    model (str) (optional): The model class from openAI. By default uses GPT-4-Turbo\n",
    "    verbose (bool) (optional): A flag to print intermediate responses by GPT\n",
    "    save (bool) (optional): A flag to save the prompts to a destination path\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    str or None: path of the saved yaml file when save is True, otherwise None\n",
    "    '''\n",
    "    gpt_assistant_prompt =  '''You are an expert in prompting text-image generation models. Given a concept to edit, your task is to generate 4 detailed prompts.\n",
    "                            1. Target prompt: a prompt that describes the target class which the concept edit is intended to modify (for example, to edit the concept \"professional\" the target concept is \"person\". Leave it empty if the target concept is too large. For example if user asks for their generations to be more futuristic, since all the images have to be edited, just leave the target \"\"\n",
    "                            2. Positive prompt: a detailed prompt that describes the extreme positive end of the edit concept with the target concept included (for example, \"person, very professional, blazer, neat, organized)\"\n",
    "                            3. Negative prompt: a detailed prompt that describes the extreme negative end of the edit concept with the target concept included (for example, \"person, non-professional, ragidy, unkempt\"). This is optional, you can leave it \"\" if there is no obvious negative prompt.\n",
    "                            4. Preservation prompt: a prompt (must be comma separated) that describes any concepts except the ones to edit that should be preserved when making the edit without the target concept included (for example, \"white, black, indian, asian, hispanic; male, female\" as the race or gender of a person may be changed when we edit the professionalism.). This should not include edit concepts and should not include any of the positive or negative concepts. if there are no obvious entanglement issues with the edit, leave the prompt \"\"\n",
    "                            make preservation prompt comma seperated for each class of perservation. For example if you want to preserve both race and gender, then give something like \"white race, black race, indian race, asian race; male, female\"\n",
    "\n",
    "                            All the prompts must be strictly string type. Be specific. Do not use any alphanumeric symbols.\n",
    "                            \n",
    "                            This is an example template for your response when asked to generate prompts for making people smile:\n",
    "                            Target: person\n",
    "                            Positive: person, smiling, happy face, big smile\n",
    "                            Negative: person, frowning, grumpy, sad\n",
    "                            Preservation: white, black, indian, asian, hispanic ; male, female\n",
    "                            Name: person_age_GPT\n",
    "                            \n",
    "                            Here is another example template for your response when asked - \"I want to make images more detailed\":\n",
    "                            Target: \n",
    "                            Positive: highly detailed, intricate patterns, fine textures, realistic shading\n",
    "                            Negative: simplistic, minimalistic, abstract, rough outlines\n",
    "                            Preservation: \n",
    "                            Name: detailed_GPT\n",
    "                            '''\n",
    "    gpt_user_prompt = slider_query\n",
    "    # NOTE(review): the instructions are sent with the 'assistant' role; the chat API also has a\n",
    "    # 'system' role meant for instructions. Kept unchanged so existing model behaviour is preserved.\n",
    "    messages = [\n",
    "        {\"role\": \"assistant\", \"content\": gpt_assistant_prompt},\n",
    "        {\"role\": \"user\", \"content\": gpt_user_prompt},\n",
    "    ]\n",
    "\n",
    "    response = client.chat.completions.create(\n",
    "        model=model,\n",
    "        messages=messages,\n",
    "        temperature=temperature,\n",
    "        max_tokens=max_tokens,\n",
    "        frequency_penalty=frequency_penalty\n",
    "    )\n",
    "    # Fall back to an empty reply if the API returns no text content, instead of failing on None\n",
    "    content = response.choices[0].message.content or ''\n",
    "    if verbose:\n",
    "        print(content)\n",
    "\n",
    "    fields, preservation_groups, name = _parse_slider_response(content)\n",
    "\n",
    "    # Training settings attached to every prompt set written to the yaml file\n",
    "    results_final = []\n",
    "    for prompt_set in _expand_preservation(fields, preservation_groups):\n",
    "        prompt_set['guidance'] = 4\n",
    "        prompt_set['rank'] = 4\n",
    "        prompt_set['action'] = 'enhance'\n",
    "        prompt_set['resolution'] = 512\n",
    "        prompt_set['dynamic_resolution'] = False\n",
    "        prompt_set['batch_size'] = 1\n",
    "        results_final.append(prompt_set)\n",
    "\n",
    "    if file_name_to_save is None:\n",
    "        if name is None:\n",
    "            file_name_to_save = 'custom-prompts-GPT.yaml'\n",
    "        else:\n",
    "            # The name comes from the model, so keep only characters that are safe in a file name\n",
    "            safe_name = ''.join(ch if ch.isalnum() or ch in '-_.' else '_' for ch in name)\n",
    "            file_name_to_save = f'prompts-{safe_name}.yaml'\n",
    "    if save:\n",
    "        save_path = f'trainscripts/textsliders/data/{file_name_to_save}'\n",
    "        # Create the destination folder when the notebook is run outside the repository root\n",
    "        os.makedirs(os.path.dirname(save_path), exist_ok=True)\n",
    "        # allow_unicode writes non-ASCII characters verbatim, so the file encoding is fixed to UTF-8\n",
    "        with open(save_path, 'w', encoding='utf-8') as f:\n",
    "            yaml.dump(results_final, f, allow_unicode=True, sort_keys=False)\n",
    "        if verbose:\n",
    "            print(f'Prompt file saved to: \"{save_path}\"')\n",
    "        return save_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "eae07933-c3d7-43d1-99ba-b87b749ec1fd",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Target: person\n",
      "Positive: person, aged, wrinkles, grey hair, elderly features\n",
      "Negative: person, youthful, smooth skin, vibrant, young\n",
      "Preservation: white, black, indian, asian, hispanic ; male, female\n",
      "Name: person_age_GPT\n",
      "Prompt file saved to: \"trainscripts/textsliders/data/prompts-person_age_GPT.yaml\"\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'trainscripts/textsliders/data/prompts-person_age_GPT.yaml'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Natural-language description of the slider to build\n",
    "query = \"I want to build a slider to make people old\"\n",
    "# The returned path of the saved yaml file is displayed as the cell result\n",
    "generate_prompts_sliders(\n",
    "    slider_query=query,\n",
    "    model=\"gpt-4-turbo-preview\",\n",
    "    save=True,\n",
    "    verbose=True,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
